function [benchmark_1, benchmark_2, scaled_ATT, ATE] = calculate_benchmark_savings(sample_cross,wave,cost,SMC,bsimax)
% CALCULATE_BENCHMARK_SAVINGS Metrics used to benchmark savings from EWM rules.
%
% Computes two benchmarks from the actual RCT: the scaled ATT (ATT times
% the share of households treated) and the ATE (savings under universal
% treatment). Both are inverse-probability-weighted (IPW) estimates and
% come with 95% bootstrap confidence intervals.
%
% Inputs
%   sample_cross : table, one row per household. Columns read here:
%                  pre_avg, post_avg, opower (treatment indicator),
%                  opower_paper_wave (wave id, wave-specific runs only) and
%                  ps (propensity score, pooled runs only).
%   wave         : 0 for the pooled sample, or 3, 6, 7 for wave-specific results.
%   cost         : 1 -> outcome is cost (price * usage change plus a
%                  per-household letter cost for treated households);
%                  0 -> outcome is the raw kWh change.
%   SMC          : 1 -> price usage at social marginal cost; 0 -> at the
%                  retail electricity price. Only used when cost = 1.
%   bsimax       : number of bootstrap replications.
%
% Outputs
%   benchmark_1  : 2-row table, scaled ATT of the actual RCT and its 95% CI.
%   benchmark_2  : 2-row table, ATE (universal treatment) and its 95% CI.
%   scaled_ATT   : numeric, ATT * (number treated) / n.
%   ATE          : numeric, IPW estimate of the ATE.
%
% Side effects: reseeds the global random number generator with rng(0) and
% sets the command-window display format to "short g".

%% Outcome construction parameters
letter_cost  = 0.765; % cost of sending out a letter per household; see footnote in paper for the calculation
price_smc    = 0.065; % social marginal cost of electricity (presumably $/kWh -- confirm against paper)
price_retail = 0.177; % retail electricity price (presumably $/kWh -- confirm against paper)

%% Fix the random seed so the bootstrap CIs are reproducible
rng(0);

%% Select the estimation sample
if wave==0
    in_sample = true(height(sample_cross),1); % pooled: all households
else
    in_sample = sample_cross.opower_paper_wave==wave;
end
usage_wave = sample_cross.post_avg(in_sample,:) - sample_cross.pre_avg(in_sample,:);
opower_wave = sample_cross.opower(in_sample,:);
% Cast explicitly so an integer/logical opower column cannot change the
% class of the outcome (the old [Y opower] concatenation could).
D = double(opower_wave(:,1));

%% Build the outcome
if (cost)
    if (SMC)
        price = price_smc;
    else
        price = price_retail;
    end
    Y_all = price.*usage_wave(:,1) + letter_cost.*D;
else
    Y_all = usage_wave(:,1);
end

% Flag passed to Y_demean: =1 in the pooled sample (demean within wave),
% =0 for wave-specific runs (demean by the "full" wave sample). DO NOT CHANGE.
% NOTE(review): in wave-specific runs Y_all covers only that wave while
% sample_cross is the full table -- confirm Y_demean ignores sample_cross
% when the flag is 0.
demean_wave = double(wave==0);
Y_all = Y_demean(Y_all,sample_cross,demean_wave);

Y = Y_all(:,1);
n = length(Y); % sample size

if wave==0
    ps = sample_cross.ps; % household-specific propensity score in the pooled sample
else
    ps = mean(D);         % propensity score is constant within a wave
end

%% Point estimates
% IPW score: its sample mean is the ATE estimate.
g = Y.*(D./ps - (1-D)./(1-ps));
ATT = ipw_att(Y,D,ps);
ATE = mean(g); % equiv. to mean(D.*Y./ps) - mean(((1-D).*Y)./(1-ps))

%% 95% bootstrap CIs (symmetric: estimate +/- 95th percentile of |bootstrap - estimate|)
bsd_att = zeros(bsimax,1);
bsd_ate = zeros(bsimax,1);
for bsi = 1:bsimax
    idx = randi(n,[n 1]);
    if wave==0
        ps_b = ps(idx,:);
    else
        ps_b = ps; % wave-specific: ps is a scalar, not resampled
    end
    bsd_att(bsi) = abs(ipw_att(Y(idx,:),D(idx,:),ps_b) - ATT);
    bsd_ate(bsi) = abs(mean(g(idx,:)) - ATE);
end

half_width_att = prctile(bsd_att,95);
half_width_ate = prctile(bsd_ate,95);
ci_lb_ATT = ATT - half_width_att;
ci_ub_ATT = ATT + half_width_att;
ci_lb_ATE = ATE - half_width_ate;
ci_ub_ATE = ATE + half_width_ate;

%% Summarize savings
percent = mean(D);              % share of households treated
scaled_ATT = ATT*(sum(D))/n;

format short g
benchmark_1 = build_benchmark_table('Actual RCT','Scaled ATT', ...
    round(percent*100), scaled_ATT, ci_lb_ATT, ci_ub_ATT);
benchmark_2 = build_benchmark_table('Universal treatment','ATE', ...
    100, ATE, ci_lb_ATE, ci_ub_ATE);

end

function att = ipw_att(Y,D,ps)
% IPW_ATT Inverse-probability-weighted ATT: treated outcomes minus
% ps/(1-ps)-reweighted control outcomes, divided by the number treated.
% Assumes D is a 0/1 treatment indicator (so sum(D) counts treated units).
att = sum(D.*Y-((1-D).*Y.*ps)./(1-ps))/sum(D);
end

function T = build_benchmark_table(rule_label,covariate_label,share_treated,estimate,ci_lb,ci_ub)
% BUILD_BENCHMARK_TABLE Two-row results table: an estimate row followed by
% a row holding its CI as the string "(lb,ub)". Estimate and CI bounds are
% rounded to 2 decimals.
var_names = {'Rules','Covariates','Share\ treated','Savings\ from\ EWM\ rules','Difference\ in\ savings\ between\ EWM\ RCT'};

estimate_row = array2table(nan(1,5),'VariableNames',var_names);
estimate_row.(1) = {rule_label};
estimate_row.(2) = {covariate_label};
estimate_row.(3) = share_treated;
estimate_row.(4) = round(estimate,2);
estimate_row.(5) = "";

ci_row = array2table(nan(1,5),'VariableNames',var_names);
ci_row.(1) = "";
ci_row.(2) = "";
ci_row.(3) = "";
ci_row.(4) = strcat('(',string(round(ci_lb,2)),',',string(round(ci_ub,2)),')');
ci_row.(5) = "";

T = [estimate_row; ci_row];
end
